Intro

The following code aims to be a step-by-step walkthrough on how to implement and evaluate a Bayesian Linear Mixed Effects model for simulated website bounce rate data across English counties. The response variable \(y\) is website bounce rate in seconds (bounce rate = time a user spends on a website), and the independent variables are age (user age) and location (county).

Load Packages

# Load packages. library() is preferred over require() for hard
# dependencies: it errors immediately when a package is missing, whereas
# require() merely returns FALSE and lets later code fail obscurely.
library(tidyverse)
library(dplyr) # dplyr 1.0.0
library(knitr)
library(kableExtra)
library(gridExtra)
library(grid)
library(ggpubr)
library(modelr)

# Animation libraries
library(gganimate)
library(gifski)
library(png)

# Mixed Effects Model libraries
library(rstan)
library(brms)
library(tidybayes)
library(bayesplot)
library(loo)

# Prior libraries
library(extraDistr)

Load Data

# Read the simulated bounce-rate data. An empty cols() spec suppresses
# readr's column-type guessing messages. county is converted to a factor
# so it can serve as the grouping variable in the mixed-effects models.
bounce_data <- read_csv("../../data/bounce_rates_sim.csv", 
                        col_types = cols() )  %>% 
  mutate(county = as.factor(county))

# Render the full data set as a scrollable, styled HTML table
# (green header row, 200px-tall scroll box).
bounce_data %>%
  kable() %>%
  kable_styling(bootstrap_options = c("striped", "hover", "condensed", "responsive")) %>%
  row_spec(0, background = "#4CAF50", color = "#FFF") %>%
  scroll_box(width = "100%", height = "200px")
bounce_time age county
225.8538 54 kent
217.9062 79 kent
220.6601 71 kent
223.4206 47 kent
213.8878 93 kent
206.4577 64 kent
234.4551 91 kent
203.5949 37 essex
203.6985 28 essex
221.7954 36 essex
212.3589 31 essex
215.5649 51 essex
201.6897 53 essex
205.5136 29 essex
211.1856 27 essex
184.2293 39 essex
203.6672 33 essex
210.3584 62 essex
211.9895 28 essex
208.9214 52 essex
213.3661 55 essex
212.4840 15 essex
201.7498 61 essex
221.7837 38 essex
209.9975 40 essex
205.7030 42 essex
196.0435 32 essex
216.0141 55 essex
198.6938 59 essex
184.5071 30 london
182.5263 24 london
180.2908 13 london
176.6328 26 london
172.5186 48 london
179.8916 1 london
161.4922 31 london
182.0156 36 london
194.6150 21 london
183.0910 53 london
184.4475 37 london
165.2503 12 london
181.7137 40 london
191.3859 25 london
171.8971 33 london
190.2567 36 london
184.4317 27 london
168.7719 14 london
178.9858 40 london
192.2939 15 london
178.4500 18 london
182.4974 16 london
182.3821 14 london
171.8979 14 london
165.9008 18 london
167.4110 59 london
185.5440 28 london
191.1776 28 london
189.3372 21 london
180.2835 32 london
172.2707 16 london
183.3813 31 london
195.2366 21 london
176.8045 6 london
170.7791 4 london
159.2062 21 london
170.3089 38 london
198.6933 48 devon
203.9536 70 devon
204.3526 71 devon
201.6434 88 devon
195.3188 110 devon
204.4924 73 devon
194.3613 72 devon
200.4898 88 devon
197.8712 85 devon
210.8970 80 devon
200.8695 83 devon
191.5367 94 devon
204.6859 56 devon
196.3898 67 devon
197.5210 79 devon
201.0421 67 devon
192.0654 26 devon
199.7605 57 devon
201.8048 86 devon
205.8579 73 devon
199.2470 72 devon
204.1690 104 devon
194.4276 72 devon
191.3841 67 devon
200.4391 80 devon
198.0288 29 devon
203.1436 57 devon
202.7748 71 devon
188.1232 88 devon
202.9197 94 devon
200.3105 68 devon
196.7775 70 devon
199.1221 64 devon
208.8129 76 devon
196.6335 74 devon
203.9025 97 devon
197.6037 73 devon
193.3032 69 devon
204.5152 80 devon
197.2597 84 devon
191.6854 55 devon
197.0951 58 devon
200.3260 106 devon
197.5353 43 devon
200.0882 64 devon
200.6170 67 devon
202.1392 79 devon
198.7077 75 devon
193.9549 44 devon
198.8015 66 devon
194.9967 81 devon
204.3312 77 devon
198.1364 86 devon
207.7474 82 devon
203.1465 64 devon
190.3096 84 devon
202.2230 92 devon
192.6323 69 devon
198.9416 59 devon
193.7122 70 devon
202.2707 79 devon
205.0753 114 devon
201.8092 89 devon
200.6674 76 devon
200.3422 63 devon
199.9825 103 devon
196.7411 62 devon
192.8372 79 devon
200.0539 84 devon
197.2865 101 devon
195.5131 78 devon
193.6063 61 devon
202.5980 87 devon
199.4300 62 devon
203.6946 82 devon
215.4648 113 dorset
200.2166 76 dorset
209.1769 57 dorset
175.9648 45 dorset
207.0599 81 dorset
190.7460 64 dorset
207.5366 76 dorset
212.6425 59 dorset
204.9520 84 dorset
209.0750 61 dorset
208.5712 44 dorset
203.9592 48 dorset
202.8222 52 dorset
211.2343 69 dorset
206.1818 64 dorset
194.8650 80 dorset
191.1600 49 dorset
199.9447 84 dorset
194.6530 29 dorset
191.9744 58 dorset
200.1989 65 dorset
209.0733 66 dorset
206.3179 44 dorset
185.9619 84 dorset
200.3820 82 dorset
193.6978 70 dorset
207.5257 61 dorset
195.5070 45 dorset
201.1868 91 dorset
201.5488 53 dorset
179.5727 54 dorset
218.2631 34 dorset
195.9146 60 dorset
199.6272 63 dorset
215.8019 96 dorset
182.6826 68 dorset
190.9447 70 dorset
195.9754 50 dorset
202.3283 41 dorset
201.4353 52 dorset
207.9298 68 dorset
210.2485 65 dorset
208.0484 47 dorset
195.9026 67 dorset
203.2758 52 dorset
193.9061 82 dorset
204.0604 63 dorset
205.3292 70 dorset
200.8077 62 dorset
195.2892 67 dorset
204.6368 65 dorset
202.5677 36 dorset
204.1111 60 dorset
203.9212 78 dorset
198.0681 47 dorset
199.2772 85 dorset
215.9674 64 dorset
198.2211 54 dorset
198.1251 69 dorset
211.6341 52 dorset
180.5467 40 dorset
200.3906 73 dorset
198.6198 61 dorset
209.8897 101 dorset
205.2606 82 dorset
222.4502 83 dorset
195.4427 60 dorset
196.0813 63 dorset
192.5134 78 dorset
217.9686 62 dorset
188.6774 41 dorset
188.7936 46 dorset
212.7729 51 dorset
202.8989 50 dorset
193.9978 39 dorset
200.7135 62 dorset
223.7702 83 dorset
194.9929 59 dorset
213.5304 40 dorset
203.4709 79 dorset
198.8800 91 dorset
205.5893 94 dorset
208.4568 80 dorset
191.7804 46 dorset
207.1614 74 dorset
183.9436 85 dorset
200.5705 66 dorset
226.9220 41 dorset
213.9588 81 dorset
203.7940 66 dorset
208.4021 68 cumbria
213.9375 71 cumbria
200.1463 71 cumbria
210.3203 66 cumbria
211.1401 73 cumbria
198.6172 68 cumbria
200.4856 73 cumbria
203.4624 46 cumbria
209.8694 68 cumbria
212.5938 13 cumbria
213.3855 69 cumbria
218.4301 46 cumbria
201.3730 49 cumbria
207.0558 51 cumbria
201.8000 74 cumbria
207.5178 73 cumbria
214.5728 70 cumbria
199.4696 67 cumbria
210.4172 64 cumbria
204.2227 70 cumbria
205.6621 75 cumbria
218.3393 52 cumbria
217.8684 75 cumbria
207.4168 87 cumbria
191.7129 35 cumbria
212.3685 35 cumbria
207.7391 53 cumbria
202.3393 31 cumbria
207.6240 60 cumbria
211.1069 49 cumbria
216.9618 60 cumbria
204.8985 41 cumbria
208.0298 70 cumbria
208.5042 63 cumbria
202.6250 52 cumbria
209.0654 71 cumbria
207.5364 66 cumbria
204.8710 66 cumbria
206.6085 45 cumbria
215.9537 62 cumbria
204.0067 52 cumbria
204.1681 70 cumbria
205.6398 99 cumbria
215.0105 61 cumbria
222.2287 60 cumbria
211.5478 93 cumbria
201.4563 35 cumbria
218.9502 46 cumbria
216.5111 129 cumbria
205.6643 32 cumbria
208.7487 48 cumbria
207.6356 75 cumbria
210.5432 79 cumbria
212.7470 61 cumbria
203.2712 49 cumbria
210.0344 83 cumbria
204.5591 42 cumbria
202.7044 44 cumbria
209.7123 33 cumbria
212.2624 74 cumbria
212.7655 65 cumbria
211.8085 63 cumbria
196.2341 46 cumbria
204.3112 48 cumbria
212.9128 48 cumbria
212.4191 49 cumbria
201.5805 67 cumbria
208.7316 47 cumbria
208.1921 77 cumbria
220.7695 47 cumbria
214.4266 53 cumbria
200.8323 69 cumbria
205.0210 49 cumbria
208.2647 50 cumbria
208.0480 43 cumbria
211.5435 69 cumbria
203.6096 41 cumbria
217.7464 92 cumbria
209.0103 57 cumbria
208.1187 49 cumbria
211.3844 58 cumbria
216.8920 65 cumbria
212.9242 72 cumbria
214.4708 54 cumbria
208.0446 83 cumbria
210.9401 60 cumbria
206.6836 69 cumbria
212.0873 68 cumbria
214.7157 50 cumbria
211.5558 70 cumbria
206.0166 36 cumbria
196.2511 79 cumbria
203.4005 58 cumbria
209.8707 48 cumbria
217.3593 58 cumbria
211.5424 75 cumbria
214.0327 70 cumbria
208.8590 73 cumbria
201.8902 72 cumbria
205.3669 75 cumbria
215.8747 86 cumbria
204.0299 62 cumbria
205.2632 53 cumbria
210.2801 39 cumbria
221.1225 72 cumbria
204.9159 48 cumbria
215.3467 75 cumbria
206.4351 63 cumbria
204.9950 76 cumbria
209.5784 48 cumbria
210.9453 57 cumbria
206.6750 46 cumbria
177.6245 30 norfolk
182.1104 60 norfolk
175.4903 24 norfolk
177.1250 3 norfolk
185.6625 43 norfolk
182.3939 34 norfolk
186.2690 43 norfolk
156.0129 25 norfolk
182.7219 33 norfolk
170.3454 16 norfolk
177.1850 23 norfolk
189.1116 17 norfolk
177.1237 31 norfolk
184.0938 26 norfolk
178.4260 10 norfolk
174.4067 10 norfolk
183.0203 -1 norfolk
178.3646 48 norfolk
170.4670 33 norfolk
178.2488 37 norfolk
174.0448 51 norfolk
166.9864 48 norfolk
185.9389 39 norfolk
183.6586 13 norfolk
167.9151 33 norfolk
181.7819 28 norfolk
177.6393 19 norfolk
190.4892 42 norfolk
189.3262 24 norfolk
175.1466 33 norfolk
172.5747 41 norfolk
179.4212 36 norfolk
177.7364 15 norfolk
195.6366 60 norfolk
188.7748 22 norfolk
173.7379 40 norfolk
187.2555 34 norfolk
186.4125 39 norfolk
174.1933 40 norfolk
172.5385 34 norfolk
179.8547 23 norfolk
184.6571 28 norfolk
177.3976 39 norfolk
175.2949 45 norfolk
181.5672 33 norfolk
179.0471 51 norfolk
176.1024 20 norfolk
198.8694 19 norfolk
187.0505 26 norfolk
178.3498 23 norfolk
188.5443 51 norfolk
184.8811 25 norfolk
176.2062 23 norfolk
174.8356 43 norfolk
182.0219 27 norfolk
175.9168 9 norfolk
172.4202 61 norfolk
189.6861 46 norfolk
171.9307 23 norfolk
185.9969 17 norfolk
181.1680 10 norfolk
187.3300 14 norfolk
184.4892 31 norfolk
185.3808 26 norfolk
175.1346 53 norfolk
174.8861 45 norfolk
172.2547 28 norfolk
173.2351 38 norfolk
181.3234 50 norfolk
167.4591 18 norfolk
198.7599 31 norfolk
180.9043 36 norfolk
179.2687 48 norfolk
191.1251 34 norfolk
183.8592 49 norfolk
186.4512 36 norfolk
172.4647 45 norfolk
185.5850 16 norfolk
190.6732 47 norfolk
181.4938 18 norfolk
175.3785 41 norfolk
186.9373 45 norfolk
176.1613 34 norfolk
168.1130 25 norfolk
176.6706 44 norfolk
183.8519 20 norfolk
199.1827 57 norfolk
177.7475 50 norfolk
183.4855 25 norfolk
193.8418 28 norfolk
161.0385 29 norfolk
157.2738 23 norfolk
177.4387 16 norfolk
170.8876 42 norfolk
187.2842 25 norfolk
173.5743 27 norfolk
184.4781 27 norfolk
187.9249 27 norfolk
180.6428 -1 norfolk
173.9960 40 norfolk
178.8904 47 norfolk
178.7011 19 norfolk
183.4040 34 norfolk
181.4545 45 norfolk
194.7127 30 norfolk
175.5028 64 norfolk
177.7311 37 norfolk
179.3548 1 norfolk
191.8206 28 norfolk
179.2707 38 norfolk
185.6427 47 norfolk
190.0737 25 norfolk
179.9949 30 norfolk
175.1075 20 norfolk
174.1200 45 norfolk
189.9190 20 norfolk
174.6957 12 norfolk
179.9953 21 norfolk
173.8736 20 norfolk
173.5177 13 norfolk
221.9283 38 cheshire
207.9079 45 cheshire
212.7421 42 cheshire
234.4364 26 cheshire
211.3112 6 cheshire
215.6437 31 cheshire
215.7583 30 cheshire
225.9816 13 cheshire
202.2175 47 cheshire
216.4464 54 cheshire
189.2030 12 cheshire
213.7356 27 cheshire
215.4297 28 cheshire
205.6262 41 cheshire
209.7567 15 cheshire
218.6768 53 cheshire
207.1210 40 cheshire
197.7208 52 cheshire
205.3591 15 cheshire
203.8642 30 cheshire
220.2534 47 cheshire
215.8066 63 cheshire
204.6500 33 cheshire
211.3456 23 cheshire
209.1033 46 cheshire
204.3843 57 cheshire
214.7387 53 cheshire
225.3257 33 cheshire
209.5644 42 cheshire
212.9073 34 cheshire
203.4897 20 cheshire
218.3658 45 cheshire
209.7521 36 cheshire
199.4428 44 cheshire
197.7247 36 cheshire
201.4917 48 cheshire
216.7140 21 cheshire
210.2920 48 cheshire
227.7717 26 cheshire
199.7231 53 cheshire
212.3758 34 cheshire
218.3416 39 cheshire
216.1571 50 cheshire
203.6633 36 cheshire
220.4490 59 cheshire
213.7362 49 cheshire
192.0860 25 cheshire
219.5891 18 cheshire
215.3220 36 cheshire
217.1592 55 cheshire
229.1985 58 cheshire
197.4211 28 cheshire
194.6101 59 cheshire
215.6196 36 cheshire
210.6817 55 cheshire
198.6076 20 cheshire
217.9418 69 cheshire
212.8855 32 cheshire
211.8936 21 cheshire
224.2742 69 cheshire
212.0488 33 cheshire
207.6988 46 cheshire
210.7533 41 cheshire
199.7975 26 cheshire
199.2004 50 cheshire
213.7474 31 cheshire
232.6517 39 cheshire
214.5394 49 cheshire
227.1374 43 cheshire
223.6423 16 cheshire
220.2531 40 cheshire
204.0903 35 cheshire
216.6103 25 cheshire
204.3529 24 cheshire
202.9545 18 cheshire
220.6908 59 cheshire
197.4419 21 cheshire
219.2022 54 cheshire
200.6058 54 cheshire
203.2914 34 cheshire
221.4731 32 cheshire
191.6479 62 cheshire
214.5692 62 cheshire
217.1172 44 cheshire
219.7899 27 cheshire
209.0349 9 cheshire
194.6766 20 cheshire
212.7522 26 cheshire
207.9540 34 cheshire
220.2253 51 cheshire
202.2456 41 cheshire
210.2562 33 cheshire
210.7081 51 cheshire
213.7315 61 cheshire
213.1203 69 cheshire
200.2013 17 cheshire
219.4747 32 cheshire
196.3528 43 cheshire
218.0399 45 cheshire
198.7352 22 cheshire
226.9599 44 cheshire
199.0811 42 cheshire
217.5453 23 cheshire
215.1293 29 cheshire
215.4584 66 cheshire
211.2915 62 cheshire
210.7704 66 cheshire
218.8084 30 cheshire
198.0270 33 cheshire
220.4636 23 cheshire
213.6203 17 cheshire
211.6374 12 cheshire
201.6045 38 cheshire
204.0669 60 cheshire
198.6537 22 cheshire
205.1926 37 cheshire
208.2451 1 cheshire
202.8818 38 cheshire
217.4679 52 cheshire
205.9116 55 cheshire
209.3939 43 cheshire
212.2201 39 cheshire
220.3121 57 cheshire
209.2730 43 cheshire
222.7670 47 cheshire
213.5815 33 cheshire
225.5685 43 cheshire
212.8360 51 cheshire
218.4750 40 cheshire
201.4827 52 cheshire
220.6145 71 cheshire
212.0379 12 cheshire
205.4092 53 cheshire
233.1305 53 cheshire
192.8963 30 cheshire
203.7808 31 cheshire
207.6008 11 cheshire
209.9624 42 cheshire
210.4145 15 cheshire
219.7713 21 cheshire
212.7669 31 cheshire
210.4563 48 cheshire
196.7473 71 cheshire
221.8812 26 cheshire
222.3395 23 cheshire
232.8309 58 cheshire
223.0557 31 cheshire
207.5521 30 cheshire
210.3355 34 cheshire
216.5485 39 cheshire
# Quick overview of the data: per the printed summary, bounce_time spans
# roughly 156-235 secs, age includes an implausible -1 (simulation
# artifact), and county group sizes are unbalanced (150 down to a handful).
summary(bounce_data)
##   bounce_time         age              county   
##  Min.   :156.0   Min.   : -1.00   cheshire:150  
##  1st Qu.:190.3   1st Qu.: 32.00   norfolk :120  
##  Median :202.6   Median : 48.00   cumbria :112  
##  Mean   :200.0   Mean   : 49.01   dorset  : 90  
##  3rd Qu.:210.8   3rd Qu.: 66.00   devon   : 75  
##  Max.   :234.5   Max.   :129.00   london  : 37  
##                                   (Other) : 29

Preprocess data

Standardize (center-scale) age variable

# Center and scale age (mean 0, sd 1); as.numeric() drops the matrix
# attributes that scale() attaches. The new column is placed right
# after the raw age column for easy side-by-side inspection.
bounce_data <- bounce_data %>%
  mutate(std_age = as.numeric(scale(age))) %>%
  dplyr::relocate(std_age, .after = age)

# Inspect the standardized column
summary(bounce_data)
##   bounce_time         age            std_age              county   
##  Min.   :156.0   Min.   : -1.00   Min.   :-2.22640   cheshire:150  
##  1st Qu.:190.3   1st Qu.: 32.00   1st Qu.:-0.75726   norfolk :120  
##  Median :202.6   Median : 48.00   Median :-0.04496   cumbria :112  
##  Mean   :200.0   Mean   : 49.01   Mean   : 0.00000   dorset  : 90  
##  3rd Qu.:210.8   3rd Qu.: 66.00   3rd Qu.: 0.75639   devon   : 75  
##  Max.   :234.5   Max.   :129.00   Max.   : 3.56111   london  : 37  
##                                                      (Other) : 29

Exploratory Data Analysis

Using a ground-up approach to modeling, we combine an EDA + modeling step to guide our modeling approach in an iterative fashion. We first build a baseline model (it doesn’t have to be Bayesian), starting simple and building up based on how the model performs.

Base Model

A simple linear regression, for example. The model doesn’t perform well: it has a small \(Adj\ R^2\) value, the residual vs fitted plot shows heteroscedastic variance (i.e. there’s a trend), and when we plot the data with the trendline estimated by this model we see it might not capture group differences, leading to an example of Simpson’s paradox.

# Baseline model: ordinary least squares of bounce time on standardized
# age, ignoring county entirely. Serves as the reference point the
# mixed-effects models below are compared against.
base_model <- lm(bounce_time ~ std_age, data=bounce_data)

summary(base_model)
## 
## Call:
## lm(formula = bounce_time ~ std_age, data = bounce_data)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -38.962 -10.056   0.458   9.638  39.253 
## 
## Coefficients:
##             Estimate Std. Error t value Pr(>|t|)    
## (Intercept) 199.9773     0.5701 350.792  < 2e-16 ***
## std_age       4.6803     0.5705   8.203 1.39e-15 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 14.11 on 611 degrees of freedom
## Multiple R-squared:  0.09921,    Adjusted R-squared:  0.09774 
## F-statistic: 67.29 on 1 and 611 DF,  p-value: 1.386e-15
# Residuals vs fitted values for the baseline OLS model, with a loess
# smoother to expose any trend. Title corrected: non-constant variance
# is *heteroscedastic* (the original label said "homoscedastic", which
# means the opposite and contradicts the surrounding text).
ggplot(base_model) + 
  geom_point(aes(x=.fitted, y=.resid)) +
  geom_smooth(aes(x=.fitted, y=.resid), se = FALSE ) +
  labs(x = "Fitted", y="Residuals", title="Variance isn't constant (heteroscedastic), we see a fanning down trend" )

# Scatter of bounce time vs standardized age, colored by county, with a
# single pooled OLS trend (black) drawn over all observations.
pooled_trend <- geom_smooth(data = bounce_data,
                            aes(x = std_age, y = bounce_time),
                            method = "lm", inherit.aes = FALSE,
                            se = FALSE, color = "black", size = 0.9)

ggplot(bounce_data, aes(x = std_age, y = bounce_time, color = county)) +
  pooled_trend +
  geom_point(alpha = 0.5) +
  labs(x = "Age (standardized)", y = "Bounce rates (secs)",
       title = "The line seems to capture the overall population trend, but might miss group trends (Simpson's paradox)") +
  theme(title = element_text(size = 8))

Mixed effects

Building on the previous notion that there might be differences between groups (i.e sources of heterogeneity) we now approach the data with a mixed effects model.

\[Y = X\boldsymbol{\beta} + Z\boldsymbol{u} + \epsilon\]

Example setup of LMM

where \(J\) denotes groups, \(X\) is a matrix of all observations, \(\beta\) is a vector of fixed effects (i.e. \(\beta_0\) and \(\beta_1\)), \(Z\) is a matrix representation of group-specific observations, \(u\) is a vector of random effects (e.g. \(b_{0,j}\) and/or \(b_{1,j}\)), and \(\epsilon\) is a noise vector.

# Same scatter as before, now overlaying per-county OLS fits (colored)
# on top of the pooled fit (black) to reveal group-level heterogeneity.
overall_fit <- geom_smooth(data = bounce_data,
                           aes(x = std_age, y = bounce_time),
                           method = "lm", inherit.aes = FALSE,
                           se = FALSE, color = "black", size = 0.9)

ggplot(bounce_data, aes(x = std_age, y = bounce_time, color = county)) +
  overall_fit +
  geom_smooth(method = "lm") +
  geom_point(alpha = 0.5) +
  labs(x = "Age (standardized)", y = "Bounce rates (secs)",
       title = "County intercepts seems to vary the most, while slopes seem very similar") +
  theme(title = element_text(size = 9))

From the plot above we can quickly see that most heterogeneity between counties comes from varying-intercepts. However, we also see there is a little variability in slopes.

Thus, we will fit both models (varying-intercepts and varying-intercepts, varying-slopes) and assess which one works best.

Bayesian Modeling

The first step in defining a Bayesian model is selecting a set of priors. In our case, we’re defining a Bayesian Linear Mixed Effects model which assumes data is normally distributed and varies across groups.



Prior predictive checks

When selecting priors we would like to select those that are considered to be weakly informative and ideally conjugate. The former simply refers to selecting a set of priors that would generate plausible data sets similar to our observed data, while the latter helps define the model as a generative model. People might confuse weakly informative priors with diffuse distributions; however, diffuse priors can generate impossible data sets (e.g. very extreme values), which can be detrimental.

Practically this means we’ll generate various simulated data sets based on a set of priors, and check that these provide plausible values.

Random intercept model

Mathematically, this model is defined as follows \[ y_{i,j} \sim N( \beta_0 + b_{0,j} + \beta_1 x_{i,j}, \sigma^2)\\ b_{0,j} \sim N(0, \tau^2)\]

and we need to set priors for the following parameters - \(\sigma\), \(\tau\), \(\beta_0\) and \(\beta_1\)

Here we choose a prior for the intercept centered around the sample mean of 200, and a prior for the slope centered around 4, given what we’ve observed for each group (i.e. intercepts varying across the 200s and a population slope of about 4). Thus, we set \(\tau \sim N_+(0,10)\) and \(\sigma \sim N_+(0,100)\), while \(\beta_0\sim N(200,1)\) and \(\beta_1 \sim N(4,1)\).

seed_nums <- 1:10

# Prior predictive simulations for the random-intercept model: one
# simulated data set per seed, stacked into a single data frame.
# Restructured from a grow-in-loop bind_rows() to lapply() + one
# bind_rows() (no repeated copying), and 1:length() replaced with
# seq_along(). Per-iteration draws and their order are unchanged, so
# with the same seeds the simulated values are identical.
sim_df <- bind_rows(lapply(seq_along(seed_nums), function(i) {
  set.seed(seed_nums[i])
  # Variance / std deviation priors (half-normal via abs())
  tau <- abs(rnorm(1, 0, sd = 10))     # sd of county intercept offsets
  sigma <- abs(rnorm(1, 0, sd = 100))  # residual sd
  epsilon <- rnorm(nrow(bounce_data), mean = 0, sd = sigma)

  # Fixed / random effects priors (8 counties -> 8 random intercepts;
  # indexing b0 by the county factor maps each row to its county offset)
  b0 <- rnorm(8, mean = 0, sd = tau)
  beta0 <- rnorm(1, mean = 200, sd = 1)
  beta1 <- rnorm(1, mean = 4, sd = 1)

  tibble(sim_no = i,
         raw = bounce_data$bounce_time,
         y_sim = beta0 + b0[bounce_data$county] +
           beta1 * bounce_data$std_age + epsilon)
}))


# Animate the 10 prior predictive simulations as a flip-book: one frame
# ("state") per simulation, with view_follow() rescaling the axes each
# frame so the changing spread of simulated values stays visible.
animate(ggplot(sim_df, aes(x=raw, y=y_sim)) +
  geom_point() +
  ggtitle("Random intercept model: notice changing y-axis, values are plausible yet not identical",
          subtitle = 'Simulation #: {closest_state}') +
  labs(x= "Bounce rates (sec)", y="Simulated rates (secs)") +
  theme(title= element_text(size=7)) +
  transition_states(as.factor(sim_no),
                    transition_length = 1,
                    state_length = 10) +
  view_follow(),
  fps =3, res=120, width=700, height=600)


Example of “Non-informative priors” and where it might go wrong

Here \(\beta_j \sim N(0,100)\) and \(\tau \sim Inv-Gamma(1,100)\) are considered diffuse priors (i.e. very vague)

# Single prior predictive draw under deliberately diffuse priors, to show
# how vague priors can generate implausible data. Draw order matches the
# weakly-informative example above: scales first, then noise, then effects.
sd_group <- sqrt(rinvgamma(1, alpha=1, beta=100))  # tau: group-level sd
sd_noise <- sqrt(rinvgamma(1, alpha=1, beta=100))  # sigma: residual sd
noise <- rnorm(nrow(bounce_data), mean=0, sd=sd_noise)

# Fixed / random effects under very vague N(0, 100) priors
rand_int <- rnorm(8, mean=0, sd=sd_group)
intercept <- rnorm(1, mean=0, sd=100)
slope <- rnorm(1, mean=0, sd=100)

# Simulated df: one y_sim per observed row
sims <- tibble(raw = bounce_data$bounce_time,
               y_sim = intercept + rand_int[bounce_data$county] +
                 slope * bounce_data$std_age + noise)

ggplot(sims, aes(x=raw, y=y_sim)) +
  geom_point(alpha=0.5, color="blue") +
  labs(x="Bounce rates(secs)", y="Simulated rates", 
       title = "We see half of the simulated values are quite improbable (i.e. negative or close to zero)")

Random intercept + slope model

Mathematically, this model is defined as follows \[ y_{i,j} \sim N( \beta_0 + b_{0,j} + (b_{1,j} +\beta_1) x_{i,j}, \sigma^2)\\ \begin{pmatrix}b_{0,j}\\b_{1,j}\end{pmatrix} \sim N\left(0, \begin{pmatrix}\tau_{00} & \tau_{01}\\ \tau_{10} & \tau_{11}\end{pmatrix}\right)\]

and we need to set priors for the following parameters - \(\sigma\), \(\tau_{00}\), \(\tau_{11}\), \(\beta_0\) and \(\beta_1\)

We choose similar priors as before and add \(\tau_{11} \sim N_+(0,10)\), the prior for the random-slope standard deviation, to the set of priors.

seed_nums <- 1:10

# Prior predictive simulations for the random intercept + slope model.
# BUG FIX: the original drew county-level slope offsets b1 but never used
# them, so y_sim was identical to the random-intercept simulation. The
# county slope offset now enters the linear predictor as
# (beta1 + b1[county]) * std_age, matching the model stated above.
# Also restructured to lapply() + bind_rows() with seq_along(); the
# per-iteration draw order (tau0, tau1, sigma, epsilon, b0, b1, beta0,
# beta1) is unchanged, so the random draws themselves are identical.
sim_df <- bind_rows(lapply(seq_along(seed_nums), function(i) {
  set.seed(seed_nums[i])
  # Variance / std deviation priors (half-normal via abs())
  tau0 <- abs(rnorm(1, 0, sd = 10))    # sd of random intercepts
  tau1 <- abs(rnorm(1, 0, sd = 10))    # sd of random slopes
  sigma <- abs(rnorm(1, 0, sd = 100))  # residual sd
  epsilon <- rnorm(nrow(bounce_data), mean = 0, sd = sigma)

  # Fixed / Random effects priors (8 counties)
  b0 <- rnorm(8, mean = 0, sd = tau0)
  b1 <- rnorm(8, mean = 0, sd = tau1)
  beta0 <- rnorm(1, mean = 200, sd = 1)
  beta1 <- rnorm(1, mean = 4, sd = 1)

  tibble(sim_no = i,
         raw = bounce_data$bounce_time,
         y_sim = beta0 + b0[bounce_data$county] +
           (beta1 + b1[bounce_data$county]) * bounce_data$std_age + epsilon)
}))


# Animate the intercept + slope prior predictive simulations, one frame
# per simulation, with view_follow() rescaling the axes each frame.
animate(ggplot(sim_df, aes(x=raw, y=y_sim)) +
  geom_point() +
  ggtitle("Random intercept/slope model: notice changing y-axis, values are plausible yet not identical",
    subtitle = 'Simulation #: {closest_state}') +
  labs(x= "Bounce rates (sec)", y="Simulated rates (secs)") +
  transition_states(as.factor(sim_no),
                    transition_length = 1,
                    state_length = 10) +
  view_follow(),
  fps =3, res=120, width=700, height=600)



Model fitting

Now that we’ve gotten a sense about priors we can now fit Bayesian models using brms which runs on rstan and uses familiar, easy-to-use formula builder (same as lme4).

Here we’ll fit the 3 Bayesian models mentioned above: Simple Linear Regression, Random Intercept Model and Random Intercept + Slope Model

Simple Linear Regression Fit

# If you are unsure on how to set up priors, you can use the get_prior
# function to see how to specify them
#get_prior(bounce_time ~ std_age, data=bounce_data)

# Bayesian simple linear regression with a Gaussian likelihood.
# brms' built-in `file` argument replaces the manual
# file.exists()/saveRDS()/readRDS() bookkeeping: the fit is saved to
# models/lin_reg.rds on first run and reloaded (no refit) thereafter.
lin_reg <- brm(bounce_time ~ std_age,
               data = bounce_data,
               family = gaussian(),
               prior = c(prior(normal(200, 1), class = Intercept), # intercept prior
                         prior(normal(4, 1), class = b),           # fixed effects prior
                         prior(normal(0, 100), class = sigma)),    # default lower bound is 0 (i.e truncated)
               warmup = 1000, # burn-in iterations discarded per chain
               iter = 5000,   # total iterations per chain
               chains = 2,    # number of MCMC chains
               control = list(adapt_delta = 0.95),
               file = "models/lin_reg")

summary(lin_reg)
##  Family: gaussian 
##   Links: mu = identity; sigma = identity 
## Formula: bounce_time ~ std_age 
##    Data: bounce_data (Number of observations: 613) 
## Samples: 2 chains, each with iter = 5000; warmup = 1000; thin = 1;
##          total post-warmup samples = 8000
## 
## Population-Level Effects: 
##           Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## Intercept   199.98      0.50   198.99   200.95 1.00     8283     5501
## std_age       4.52      0.50     3.53     5.50 1.00     7042     5678
## 
## Family Specific Parameters: 
##       Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sigma    14.14      0.40    13.37    14.95 1.00     7043     5302
## 
## Samples were drawn using sampling(NUTS). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).


MCMC diagnostics

We do some early diagnostics on the posterior parameter inferences and the MCMC chains (i.e. traceplots and autocorrelation).

HMC (Hamiltonian Monte Carlo) tends to provide good mixing of MCMC chains and low autocorrelation (still computed here), however there are additional diagnostics (i.e. transition divergence) that can be useful to look at with HMC sampling.

This later part is particularly important if the brms fit output warns us that there seems to be divergent transitions.

# get naming of variables
#get_variables(lin_reg) 

## Divergence: bivariate scatter plots of posterior draws with divergent
## transitions highlighted in green (none expected for this fit).
## Typo fixed in the displayed title: "HMC", not "HCM".
color_scheme_set("darkgray")
d1 <- mcmc_scatter(
  as.matrix(lin_reg),
  pars = c("sigma", "b_Intercept"),
  alpha = 2/3,
  np = nuts_params(lin_reg),
  np_style = scatter_style_np(div_color = "green", div_size = 2.5, div_alpha = 0.75)) 

d2 <- mcmc_scatter(
  as.matrix(lin_reg),
  pars = c("sigma", "b_std_age"),
  alpha = 2/3,
  np = nuts_params(lin_reg),
  np_style = scatter_style_np(div_color = "green", div_size = 2.5, div_alpha = 0.75))

grid.arrange(d1,d2, ncol=2, 
             top = textGrob("Linear Regression HMC Divergence plots: \n No divergence (green) encountered, thus space was explored entirely",gp=gpar(fontsize=12,font=1))) 

# Numeric check of divergence
# lin_reg %>% 
#  spread_draws(divergent__) %>% 
#  mean_hdi()


## Traceplots of the chains for the three model parameters.
## NOTE(review): brms fits expose post-warmup draws only, so the
## n_warmup = 500 shading may mislabel the first 500 *kept* iterations
## as warmup -- confirm this is intended.
color_scheme_set("mix-brightblue-gray")
mcmc_trace(lin_reg,  pars = c("b_Intercept", "b_std_age", "sigma"), n_warmup = 500,
                facet_args = list(ncol = 2, labeller = label_parsed)) +
  labs(x = "Iteration", title="Linear Regression MCMC chain traceplot seems to mix well")

## Autocorrelation: low lag-1/lag-2 ACF indicates efficient sampling
mcmc_acf(lin_reg, pars = c("b_Intercept", "b_std_age", "sigma"), lags = 15) +
  ggtitle("Linear Regression ACF plot",
          subtitle="MCMC ACF for all params seems to have at most lag 1-2 which is desirable")

Random Intercept Fit

Here we fit the random intercept model as defined previously

#get_prior(bounce_time ~ std_age + (1|county), data=bounce_data)

# Random-intercept model: a shared slope for std_age plus a
# county-specific intercept offset b_{0,j} ~ N(0, tau^2).
# brms' `file` argument caches the fit to models/r_intercept.rds and
# reloads it on subsequent runs, replacing the manual
# file.exists()/saveRDS()/readRDS() pattern.
bayes_rintercept <- brm(bounce_time ~ std_age + (1|county),
                        data = bounce_data,
                        prior = c(prior(normal(200, 1), class = Intercept), # intercept prior
                                  prior(normal(4, 1), class = b),   # fixed effects prior
                                  prior(normal(0, 100), class = sigma), # population variance
                                  prior(normal(0, 10), class = sd)),    # i.e. tau, group variance
                        warmup = 1000, # burn-in iterations discarded per chain
                        iter = 5000,   # total iterations per chain
                        chains = 2,    # number of MCMC chains
                        control = list(adapt_delta = 0.95),
                        file = "models/r_intercept")

summary(bayes_rintercept)
##  Family: gaussian 
##   Links: mu = identity; sigma = identity 
## Formula: bounce_time ~ std_age + (1 | county) 
##    Data: bounce_data (Number of observations: 613) 
## Samples: 2 chains, each with iter = 5000; warmup = 1000; thin = 1;
##          total post-warmup samples = 8000
## 
## Group-Level Effects: 
## ~county (Number of levels: 8) 
##               Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sd(Intercept)    13.16      3.02     8.54    20.30 1.00     1316     2439
## 
## Population-Level Effects: 
##           Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## Intercept   200.03      0.98   198.12   201.99 1.00     3413     4352
## std_age       1.81      0.42     0.98     2.65 1.00     5570     5048
## 
## Family Specific Parameters: 
##       Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sigma     7.96      0.23     7.53     8.44 1.00     6101     4887
## 
## Samples were drawn using sampling(NUTS). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).


MCMC diagnostics

We perform similar checks and notice there isn’t any lack of chain mixing, autocorrelation issues or divergence in the HMC sampling

# get naming of variables
#get_variables(bayes_rintercept) 

## Divergence: scatter of two county random intercepts (Devon vs Kent),
## with divergent transitions highlighted in green (none expected).
color_scheme_set("darkgray")
mcmc_scatter(
  as.matrix(bayes_rintercept),
  pars = vars(contains("devon"), contains("kent")),
  np = nuts_params(bayes_rintercept),
  np_style = scatter_style_np(div_color = "green", div_alpha = 0.75)) +
  labs(x = "Devon rand intercept", y = "Kent rand intercept") +
  ggtitle("HMC Divergence Random Intercept model",
          subtitle = "Example of Devon vs Kent random intercepts, no divergence found")

# Numeric check of divergence
# bayes_rintercept %>% 
#  spread_draws(divergent__) %>% 
#  mean_hdi()

# Traceplots of sigma, the fixed effects, and all county random
# intercepts (selected via the r_county regex).
# NOTE(review): brms fits expose post-warmup draws only, so the
# n_warmup = 500 shading may mislabel kept iterations -- confirm.
color_scheme_set("mix-brightblue-gray")
mcmc_trace(bayes_rintercept,  pars = vars(sigma, starts_with("b_")),
           regex_pars = "r_county.*", n_warmup = 500,
           facet_args = list(ncol = 4, labeller = label_parsed)) +
  labs(x = "Iteration", title="Random Intercept MCMC chain traceplots seem to mix well")

## Autocorrelation of the county random intercepts
mcmc_acf(bayes_rintercept, pars = vars(starts_with("r_county[")),
         lags = 10,
         facet_args = list( labeller = label_parsed)) +
  ggtitle("Random Intercept ACF plot",
          subtitle="MCMC ACF for all params seems to have at most lag 1-2 which is desirable") +
  theme(text = element_text(size = 8))

Random Intercept + Slope Fit

We then fit a random intercept and slopes model as specified above

#get_prior(bounce_time ~ std_age + (1 + std_age |county), data=bounce_data)

# Fit (or load a cached) random intercept + random slope model: each county
# gets its own deviation for both the baseline bounce time and the age effect.
if (!file.exists("models/r_slope.rds")){
  
  bayes_rslope <- brm(bounce_time ~ std_age + (1 + std_age|county),
                        data = bounce_data,
                        prior = c(prior(normal(200, 1), class = Intercept), # intercept prior
                                  prior(normal(4, 1), class = b), # fixed effects prior
                                  prior(normal(0, 100), class = sigma), # population variance
                                  prior(normal(0, 10), class = sd, 
                                        group=county, coef="Intercept"), #tau 0
                                  prior(normal(0, 10), class = sd, 
                                        group=county, coef="std_age")),  #tau 1
                        warmup = 1000, # burn-in
                        iter = 5000, # number of iterations
                        chains = 2,   # number of MCMC chains
                        control = list(adapt_delta = 0.95)) # smaller steps, fewer divergences
  
  # Cache the fit so re-knitting the document does not re-run MCMC
  saveRDS(bayes_rslope, file="models/r_slope.rds")
} else {
  bayes_rslope <- readRDS("models/r_slope.rds")
}


# Posterior summary: estimates, 95% CIs, Rhat and ESS diagnostics
summary(bayes_rslope)
##  Family: gaussian 
##   Links: mu = identity; sigma = identity 
## Formula: bounce_time ~ std_age + (1 + std_age | county) 
##    Data: bounce_data (Number of observations: 613) 
## Samples: 2 chains, each with iter = 5000; warmup = 1000; thin = 1;
##          total post-warmup samples = 8000
## 
## Group-Level Effects: 
## ~county (Number of levels: 8) 
##                        Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
## sd(Intercept)             13.32      3.12     8.48    20.74 1.00     2360
## sd(std_age)                0.69      0.69     0.03     2.54 1.00     3885
## cor(Intercept,std_age)     0.08      0.54    -0.92     0.94 1.00     9282
##                        Tail_ESS
## sd(Intercept)              3346
## sd(std_age)                3789
## cor(Intercept,std_age)     5224
## 
## Population-Level Effects: 
##           Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## Intercept   200.05      0.98   198.15   201.98 1.00     5845     6188
## std_age       1.94      0.55     0.98     3.16 1.00     6203     3688
## 
## Family Specific Parameters: 
##       Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sigma     7.97      0.23     7.53     8.44 1.00    11295     5202
## 
## Samples were drawn using sampling(NUTS). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).


MCMC diagnostics

Similar checks show no issues on chain mixing, autocorrelation or divergence in the HMC sampling

# get_variables(bayes_rslope) 

## Divergence
# Scatter of Kent's random intercept vs random slope with NUTS divergent
# transitions highlighted in green (none expected).
color_scheme_set("darkgray")
mcmc_scatter(
  as.matrix(bayes_rslope),
  pars = vars(contains("kent")),
  np = nuts_params(bayes_rslope), # FIX: diagnostics must come from the slope model, not bayes_rintercept
  np_style = scatter_style_np(div_color = "green", div_alpha = 0.75)) +
  labs(x = "Kent rand intercept", y = "Kent rand slope") +
  ggtitle("HMC Divergence Random Intercept + Slope model",
          subtitle = "Example of Kent random effects, no divergence found")

# Numeric check of divergence
# bayes_rslope %>% 
#  spread_draws(divergent__) %>% 
#  mean_hdi()


## Traceplots
color_scheme_set("mix-brightblue-gray")

# Chain-mixing check for the fixed slope and each county's random slope
mcmc_trace(bayes_rslope,  pars = vars(contains("std_age")), 
           n_warmup = 500,
           facet_args = list(ncol = 4, labeller = label_parsed)) +
  labs(x = "Iteration") +
  ggtitle("Random Intercept + Slope Model MCMC Traceplot",
          subtitle="Slope (fixed + random) estimates") +
  theme(text=element_text(size=7),
        title = element_text(size=10))

## Autocorrelation
mcmc_acf(bayes_rslope, pars = vars(matches("^r_county.*Intercept")), 
         lags = 10,
         facet_args = list(labeller = label_parsed)) +
  ggtitle("Random Intercept + Slope Model  ACF plot",
          subtitle="MCMC ACF random intercepts seem to have slight autocorrelation") +
  theme(text=element_text(size=7),
        title = element_text(size=10))

Model Comparison

Posterior Predictive Checks

We use posterior predictive checks to evaluate our models. Here we cover some qualitative and quantitative posterior predictive checks that can help in our model comparison.

For these checks we simulate from the posterior predictive distribution (i.e. sample from the posterior predictive distribution) \[p(\tilde{y} \mid y) = \int p(\tilde{y} | \theta)\; p(\theta | y)\; d \theta\] where y are our observations, \(\tilde{y}\) is new data to be predicted, and \(\theta\) are model parameters (inferred via the Bayesian fit).

Posterior Densities

First, we compare the simulated vs observed densities of bounce_time. We observe that the mixed effects models seem to simulate data that is similar to the empirical data compared to simulated simple linear regression.

# Plot posterior simulated densities 

# Get observed values
y_obs <- bounce_data$bounce_time

# Linear regression
# Overlay the densities of 100 posterior predictive draws (y_rep) on the
# density of the observed data (y).
# NOTE(review): `nsamples` was later deprecated in brms in favour of
# `ndraws` -- keep in mind if upgrading.
color_scheme_set("red")
pdense1 <- ppc_dens_overlay(y = y_obs,
                           yrep = posterior_predict(lin_reg, nsamples = 100)) +
  labs(x = "Bounce time (s)", y ="Density", title = "Linear Regression") +
  theme(title = element_text(size=10))

# Random Intercept
color_scheme_set("gray" )
pdense2 <- ppc_dens_overlay(y = y_obs,
                           yrep = posterior_predict(bayes_rintercept, nsamples = 100)) +
  labs(x = "Bounce time (s)", title = "Random Intercept") +
  theme(title = element_text(size=10))

# Random Slope
color_scheme_set("teal")
pdense3 <- ppc_dens_overlay(y =y_obs,
                           yrep = posterior_predict(bayes_rslope, nsamples = 100)) +
  labs(x = "Bounce time (s)", title = "Random Intercept + Slope") +
  theme(title = element_text(size=10))
  

# Aggregate and plot pictures side-by-side
ppc_figure <- ggarrange(pdense1, pdense2, pdense3, ncol=3,
                        common.legend = TRUE, legend = "bottom",
                        font.label = list(size=10))

annotate_figure(ppc_figure,
               top = text_grob("Posterior predictive draws (y_rep) for each model vs observed data (y)"))

Posterior Summary Statistics

Next, we look at some summary statistics. Given posterior predictive checks use data twice (i.e. one for fitting and one for checking), we look at summary statistics unrelated to any of the parameters we updated. In our case, we will use median and skewness instead of mean and variance (which we inferred).

Below is an example of why we’d look at unrelated summary statistics. We observe that all model simulations seem to capture the observed mean of bounce_time, however when we look at the skewness we observe the simple regression model simulations don’t reflect the observed data’s skewness.

# Adjusted Fisher-Pearson standardized moment coefficient of skewness:
# G1 = (sqrt(n-1) / (n-2)) * n * m3 / m2^1.5 (sample-size adjusted form),
# see https://www.itl.nist.gov/div898/handbook/eda/section3/eda35b.htm
#
# Args:
#   y: numeric vector with at least 3 observations.
# Returns:
#   Scalar skewness coefficient (0 for symmetric data, >0 right-skewed).
skew <- function(y){
  stopifnot(is.numeric(y), length(y) > 2) # n-2 divisor is undefined for n <= 2
  n <- length(y)
  dif <- y - mean(y)
  (sqrt(n - 1) / (n - 2)) * n * (sum(dif^3) / (sum(dif^2)^1.5))
}

# Helper: posterior predictive check of a summary statistic.
# Draws `samples` posterior predictive replicates from `model`, computes
# `statistic` on each replicate (T(y_rep)) and plots the histogram against
# the observed value T(obs).
posterior_stat_plot <- function(obs, model, samples=1000, statistic="mean"){
  yrep_draws <- posterior_predict(model, nsamples = samples)
  ppc_stat(y = obs, yrep = yrep_draws, stat = statistic)
}

# Linear Reg Mean and skewness
# For each model: T(y_rep) histogram of the mean (a fitted statistic) and of
# the skewness (a statistic unrelated to the fitted parameters) vs T(y).
color_scheme_set("red")
pmean1 <- posterior_stat_plot(y_obs, lin_reg) + labs(y="Linear Regression") +
  theme(legend.text = element_text(size=8), 
        legend.title = element_text(size=8))

pskew1 <- posterior_stat_plot(y_obs, lin_reg, statistic = "skew") +
  theme(legend.text = element_text(size=8),
        legend.title = element_text(size=8))

# Random Intercept Mean and skewness
color_scheme_set("gray")
pmean2 <- posterior_stat_plot(y_obs, bayes_rintercept) + labs(y="Random Intercept")
pskew2 <- posterior_stat_plot(y_obs, bayes_rintercept, statistic = "skew")

# Random Intercept + Slope Mean and skewness
color_scheme_set("teal")
pmean3 <- posterior_stat_plot(y_obs, bayes_rslope) +
  labs(x = "Mean", y ="Random Intercept + Slope")

pskew3 <- posterior_stat_plot(y_obs, bayes_rslope, statistic = "skew") +
  labs(x = "Fisher-Pearson Skewness Coeff")

# Arrange one row per model (mean | skewness), then stack the rows
ppc_stat1_figure <- ggarrange(pmean1, pskew1, legend="top",
                        font.label = list(size=10))
ppc_stat2_figure <- ggarrange(pmean2, pskew2, legend = "none",
                        font.label = list(size=10))
ppc_stat3_figure <- ggarrange(pmean3, pskew3, legend= "none",
                              font.label = list(size=10))

stat_figure <- ggarrange(ppc_stat1_figure, ppc_stat2_figure, ppc_stat3_figure, nrow=3)

annotate_figure(stat_figure,
               top = text_grob("Posterior test statistics T(y_rep) for each model vs observed data T(y)"))

Similarly, we can look at how each model’s simulations captures each group’s median instead of the mean. Again we observe the mixed effects models capture each county’s median compared to the simple linear regression.

# Compute PPC test statistics per group for each model

# Linear regression: per-county median of 500 posterior predictive draws
# vs the observed county median
color_scheme_set("red")
med_figure <- ggarrange(lin_reg %>% 
                          posterior_predict(nsamples=500) %>% 
                          ppc_stat_grouped(y = y_obs,
                                           group = bounce_data$county, stat = "median") +
                          ggtitle("Linear Regression") +
                          theme(title = element_text(size=10)),
                        nrow = 1, legend = "none")

annotate_figure(med_figure,
               top = text_grob("Posterior median T(y_rep) for each model vs observed data T(y) across counties"))

# Random Intercept
color_scheme_set("gray")
ggarrange(bayes_rintercept %>% 
            posterior_predict(nsamples=500) %>% 
            ppc_stat_grouped(y = y_obs,
                             group = bounce_data$county, stat = "median") + 
            ggtitle("Random Intercept") +
            theme(title = element_text(size=10)),
          nrow = 1, legend = "none")

# Random slope
# FIX: this panel previously drew from bayes_rintercept (copy-paste bug), so
# the "Random Intercept + Slope" figure showed the wrong model's predictions.
color_scheme_set("teal")
ggarrange(bayes_rslope %>% 
            posterior_predict(nsamples=500) %>% 
            ppc_stat_grouped(y = y_obs,
                             group = bounce_data$county, stat = "median") +
            ggtitle("Random Intercept + Slope") +
            theme(title = element_text(size=10)),
          nrow = 1, legend = "bottom")

LOO Cross-validation (Marginal Predictive Checks)

We can also check how well our model performs based on marginal predictive distributions \[p(\tilde{y}_i\mid y)\] rather than using the joint densities as above. This can be useful to find outliers or check overall calibration.

In a Bayesian setting, we can compute leave-one-out cross-validation (LOO CV) based on the LOO predictive distribution \[p(y_i \mid y_{-i})\] estimated via a Pareto-smoothed importance sampling (PSIS) method.

Calibration

First, to check overall model calibration we leverage some probability theory, specifically the concept of the probability integral transform. In our case, given we have a continuous outcome, this means that if the posterior predictive density is the true data-generating distribution, then the predictive CDF evaluated at the observations should follow a Uniform distribution (as \(n \to \infty\)).

From the graphs below, we see that the linear regression does follow this uniform shape while the mixed effects models seem to show some deviations. This is likely due to how the data was simulated and should be taken with a grain of salt for this example, but it should help show another diagnostic tool you can use.

# Compute LOO CV for each model, keeping PSIS object for later use
loo1 <- loo(lin_reg, save_psis = TRUE, cores = 2)
lw1 <- weights(loo1$psis_object) # PSIS log importance weights

loo2 <- loo(bayes_rintercept, save_psis = TRUE, cores=2)
lw2 <- weights(loo2$psis_object)

loo3 <- loo(bayes_rslope, save_psis = TRUE, cores = 2)
lw3 <- weights(loo3$psis_object)

# Plot LOO probability integral transform: if the model is well calibrated
# the LOO-PIT values should look approximately uniform.
color_scheme_set("red")
pit1 <- ppc_loo_pit_overlay(y_obs, posterior_predict(lin_reg), lw = lw1)

color_scheme_set("gray")
pit2 <- ppc_loo_pit_overlay(y_obs, posterior_predict(bayes_rintercept), lw = lw2)

color_scheme_set("teal")
pit3 <- ppc_loo_pit_overlay(y_obs, posterior_predict(bayes_rslope), lw = lw3)

pit_figure <- ggarrange(pit1, pit2, pit3, ncol=3, 
                        common.legend = TRUE, 
                        legend = "bottom")


# FIX: corrected title typo "Transfrom" -> "Transform"
annotate_figure(pit_figure,
               top = text_grob("LOO CV Probability Integral Transform for each model "))

Even if the linear regression model’s predictive CDF seems to show a good fit, we can see that the mixed effects models LOO predictive intervals capture the true values better than the linear regression model.

# Posterior predictive LOO Interval plots for each model
# LOO predictive intervals for the first 50 observations, overlaid with the
# observed values -- shows how well each model's marginals cover the data.
color_scheme_set("red")
int1 <- ppc_loo_intervals(y_obs, posterior_predict(lin_reg),
                          psis_object = loo1$psis_object, subset = 1:50) +
  theme(axis.title.x = element_blank())
  

color_scheme_set("gray")
int2 <- ppc_loo_intervals(y_obs, posterior_predict(bayes_rintercept), 
                         psis_object = loo2$psis_object, subset = 1:50) +
  theme(axis.title.x = element_blank())

color_scheme_set("teal")
int3 <- ppc_loo_intervals(y_obs, posterior_predict(bayes_rslope),
                             psis_object = loo3$psis_object, subset = 1:50) 

# Stack the three models' interval plots into one figure
int_figure <- ggarrange(int1, int2, int3, nrow=3, 
                        common.legend = TRUE, 
                        legend = "bottom")

annotate_figure(int_figure,
                top = text_grob("LOO CV Predictive Intervals for first 50 observations for each model "))

ELPD Comparisons, Outliers, and Influential Points

As always we would like to know what the predictive accuracy of our models is given some out-of-sample data (which is unknown). In Bayesian modeling, we can measure this via the expected log predictive density (ELPD) for out-of-sample data. This basically averages the model’s predictive performance over the distribution of future data.

We compute the overall ELPD for each model via LOO CV log-predictive values, and can use it to compare the expected predictive performance of each model. Here the random intercept model seems to perform best (i.e. 0.0 which is the baseline). The more negative this difference is, the worse the model performs.

# Compare ELPD between models
# loo_compare ranks models by expected log predictive density; the best model
# is the baseline (elpd_diff = 0) and more negative differences are worse.
comp <- loo_compare(loo1, loo2, loo3)

comp
##                  elpd_diff se_diff
## bayes_rintercept    0.0       0.0 
## bayes_rslope       -1.2       0.3 
## lin_reg          -347.8      18.8

We can also use LOO log-predictive to find observations that are difficult to predict amongst models (i.e. outliers or high influence points). In terms of model comparison we can look at which model best captures these observations.

We can do this via ELPD pointwise comparison between models, as well as looking at some diagnostics related to the PSIS method used to compute the LOO log-predictive densities.

Depending on how model values are compared (i.e. Model1 - Model2 vs Model2 - Model1), positive or negative differences indicate better performing models. In this case, the left figure shows that the random intercept model outperforms the linear regression model overall (positive values) while the right one shows it outperforms the random slope model (negative values) too.

# Obtain pointwise ELPD values
elpd1 <- loo1$pointwise[,"elpd_loo"]
elpd2 <- loo2$pointwise[,"elpd_loo"]
elpd3 <- loo3$pointwise[,"elpd_loo"]

# Build differences dataframe
# diff12 > 0 means the random intercept model predicts observation i better
# than the linear regression; diff23 > 0 favours the slope model over the
# random intercept model.
elpd_df <- tibble(county = bounce_data$county,
                  diff12 = elpd2 - elpd1,
                  diff23 = elpd3 - elpd2) %>% 
  mutate(idx = 1:n()) # observation index for the x-axis

# Plot each difference individually
pw_elpd1 <- ggplot(elpd_df, aes(x = idx, y = diff12, color = county)) +
  geom_point(alpha=0.7) +
  geom_hline(aes(yintercept=0)) + # reference line: no difference
  theme_bw() + 
  labs(x = "Index", y = "ELPD Difference",
       title = "Random Intercept - Linear Regression ", color="County") +
  theme(title = element_text(size=8))

pw_elpd2 <- ggplot(elpd_df, aes(x = idx, y = diff23, color = county)) +
  geom_point(alpha=0.7) +
  geom_hline(aes(yintercept=0)) +
  theme_bw()+
  labs(x = "Index",
       title = "Random Intercept/Slope - Random Intercept")+
  theme(axis.title.y = element_blank(),
        title = element_text(size=8))

# Group figures from above
annotate_figure(ggarrange(pw_elpd1, pw_elpd2, common.legend = TRUE, legend = "bottom"),
                top = text_grob("LOO ELPD Pointwise Model comparisons, Random Intercept model performs best"))

We can look at some PSIS diagnostics, to get a sense of why the random intercept outperforms the random slope. Below are the values for the Pareto k \(\hat{k}\) diagnostic which is related to the variance of this type of importance sampler.

Briefly, \(\hat{k} < 0.5\) indicates the importance weights have finite variance and the estimate converges quickly, which translates to the estimate having low RMSE. If instead \(0.5 \leq \hat{k} < 0.7\) then the PSIS importance weights still have finite variance, however convergence slows as the value increases. If \(\hat{k} \geq 0.7\), then we can’t assure convergence and the estimated values aren’t reliable.

In our case, none are greater than 0.5 however we can notice that the random intercept model resolves these values better than the random slope (\(\hat{k}_{intercept} < \hat{k}_{slope}\) and are closely distributed around \(y=0\))

# Get khat values
# Pareto-k diagnostics from PSIS: k < 0.5 good, 0.5-0.7 usable, >= 0.7 unreliable
k_rintercept <- loo2$psis_object$diagnostics$pareto_k
k_rslope <- loo3$psis_object$diagnostics$pareto_k

# Plot values
# Long format so both models can be faceted side-by-side per observation
tibble(idx = seq_along(k_rintercept), 
       r_intercept = k_rintercept,
       r_slope = k_rslope) %>% 
  pivot_longer(cols = -idx) %>% 
  ggplot(., aes(x = idx, y = value, color = name)) +
  geom_point(alpha=0.5) +
  geom_hline(aes(yintercept=0)) +
  theme_bw() +
  ylim(-1,1) +
  facet_wrap(~ name) +
  labs(y = expression(hat(k)), x ="Observation Index") +
  ggtitle("Pareto-k diagnostic (PSIS diagnostic), no influence points (k<0.7)",
          subtitle = "Random intercept models resolves left observation better")

Final Model

From the diagnostics above it is clear that the best model is the random intercept model, which I summarize its effects below

#Posterior estimates
# Median + 95% highest-density interval for the main parameters of the
# chosen random intercept model, rendered as a scrollable HTML table.
post_beta_means <- bayes_rintercept %>% 
  recover_types(bayes_rintercept) %>% 
  gather_draws(b_Intercept, b_std_age, sigma,
               sd_county__Intercept) %>% 
  median_hdi() %>% 
  dplyr::select(1:4) # keep variable name, point estimate and interval bounds

kable(post_beta_means, col.names = c("Variable", "Estimate", "Q2.5", "Q97.5"),
      caption = "Random Intercept Fixed Effects Posterior Median and 95% HPDI ") %>%
  kable_styling(bootstrap_options = c("striped", "hover", "condensed", "responsive")) %>%
  row_spec(0, background = "#4CAF50", color="#FFF")  %>% 
  scroll_box(width = "100%", height = "200px") 
Random Intercept Fixed Effects Posterior Median and 95% HPDI
Variable Estimate Q2.5 Q97.5
b_Intercept 200.000734 198.2277541 202.084766
b_std_age 1.804292 0.9809103 2.649364
sd_county__Intercept 12.712010 8.0292149 19.369187
sigma 7.953067 7.5268567 8.424336
# Random effect estimates + 95% HDPI
# Posterior density of each county's random intercept (deviation from the
# population intercept); dashed vertical line marks 0 (no deviation).
bayes_rintercept %>% 
  spread_draws(r_county[county]) %>%
  ggplot(aes(y = fct_rev(county), x = r_county, fill=fct_rev(county))) +
  geom_vline(aes(xintercept=0), alpha=0.5, linetype=2) +
  # NOTE(review): stat_halfeyeh() was deprecated in later tidybayes versions
  # (replaced by stat_halfeye()) -- keep in mind if upgrading.
  stat_halfeyeh(.width = c(0.5, .95), point_interval= median_hdi,
                alpha=0.8) +
  labs(x = "Estimate Values", y = "County" ) +
  ggtitle("Random Intercept Estimates per county",
          subtitle = "Median + 95% HPD Interval ") +
  theme_bw() +
  theme(legend.position = "None")

We can also visualize the expected posterior predicted lines and associated uncertainty using add_fitted_draws

# Expected posterior fitted lines per county with uncertainty ribbons:
# for each county, evaluate the fit on a 101-point grid of std_age using
# 100 posterior draws, then overlay the observed points.
bounce_data %>% 
  group_by(county) %>% 
  data_grid(std_age = seq_range(std_age, n = 101)) %>% 
  add_fitted_draws(bayes_rintercept, n=100) %>%
  ggplot(aes(x = std_age, y = bounce_time, color = ordered(county))) +
  stat_lineribbon(aes(y = .value), alpha=0.9) +
  geom_point(data = bounce_data, alpha=0.7) +
  theme_bw() +
  scale_fill_brewer(palette = "Greys") +
  scale_color_brewer(palette = "Dark2", guide="none") +
  facet_wrap(~county, scales = "free") +
  theme(legend.position = "bottom",
        legend.text = element_text(size=8),
        legend.title = element_text(size=8),
        axis.title = element_text(size=8),
        axis.text = element_text(size=8)) +
  # FIX: corrected axis label typo "standardize" -> "standardized"
  labs(fill= "Confidence level", x = "Age (standardized)", y="Bounce time (s)") +
  ggtitle("Expected Posterior predicted fits + 95% CI",
          subtitle = "Counties with fewer samples exhibit higher uncertainty")

References

  1. Gelman, A., Carlin, J. B., Stern, H. S., Dunson, D. B., Vehtari, A., & Rubin, D. B. (2013). Bayesian data analysis. CRC press.

  2. Gabry, Jonah, et al. “Visualization in Bayesian workflow.” Journal of the Royal Statistical Society: Series A (Statistics in Society) 182.2 (2019): 389–402.

  3. BRMS. https://paul-buerkner.github.io/brms/